#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# pytype: skip-file

import datetime
import decimal
import io
import json
import logging
import math
import re
import unittest
from typing import Optional
from typing import Sequence

import fastavro
import mock
import numpy as np
import pytz
from parameterized import parameterized

import apache_beam as beam
from apache_beam.io.gcp import resource_identifiers
from apache_beam.io.gcp.bigquery_tools import JSON_COMPLIANCE_ERROR
from apache_beam.io.gcp.bigquery_tools import AvroRowWriter
from apache_beam.io.gcp.bigquery_tools import BigQueryJobTypes
from apache_beam.io.gcp.bigquery_tools import JsonRowWriter
from apache_beam.io.gcp.bigquery_tools import RowAsDictJsonCoder
from apache_beam.io.gcp.bigquery_tools import beam_row_from_dict
from apache_beam.io.gcp.bigquery_tools import check_schema_equal
from apache_beam.io.gcp.bigquery_tools import generate_bq_job_name
from apache_beam.io.gcp.bigquery_tools import get_beam_typehints_from_tableschema
from apache_beam.io.gcp.bigquery_tools import parse_table_reference
from apache_beam.io.gcp.bigquery_tools import parse_table_schema_from_json
from apache_beam.io.gcp.internal.clients import bigquery
from apache_beam.metrics import monitoring_infos
from apache_beam.metrics.execution import MetricsEnvironment
from apache_beam.options.value_provider import StaticValueProvider
from apache_beam.typehints.row_type import RowTypeConstraint
from apache_beam.utils.timestamp import Timestamp

# Protect against environments where bigquery library is not available.
# pylint: disable=wrong-import-order, wrong-import-position
try:
  from apitools.base.py.exceptions import HttpError
  from apitools.base.py.exceptions import HttpForbiddenError
  from google.api_core.exceptions import ClientError
  from google.api_core.exceptions import DeadlineExceeded
  from google.api_core.exceptions import InternalServerError
except ImportError:
  ClientError = None
  DeadlineExceeded = None
  HttpError = None
  HttpForbiddenError = None
  InternalServerError = None
  google = None
# pylint: enable=wrong-import-order, wrong-import-position


@unittest.skipIf(HttpError is None, 'GCP dependencies are not installed')
class TestTableSchemaParser(unittest.TestCase):
  def test_parse_table_schema_from_json(self):
    string_field = bigquery.TableFieldSchema(
        name='s', type='STRING', mode='NULLABLE', description='s description')
    number_field = bigquery.TableFieldSchema(
        name='n', type='INTEGER', mode='REQUIRED', description='n description')
    record_field = bigquery.TableFieldSchema(
        name='r',
        type='RECORD',
        mode='REQUIRED',
        description='r description',
        fields=[string_field, number_field])
    expected_schema = bigquery.TableSchema(fields=[record_field])
    json_str = json.dumps({
        'fields': [{
            'name': 'r',
            'type': 'RECORD',
            'mode': 'REQUIRED',
            'description': 'r description',
            'fields': [{
                'name': 's',
                'type': 'STRING',
                'mode': 'NULLABLE',
                'description': 's description'
            },
                       {
                           'name': 'n',
                           'type': 'INTEGER',
                           'mode': 'REQUIRED',
                           'description': 'n description'
                       }]
        }]
    })
    self.assertEqual(parse_table_schema_from_json(json_str), expected_schema)


@unittest.skipIf(HttpError is None, 'GCP dependencies are not installed')
class TestTableReferenceParser(unittest.TestCase):
  def test_calling_with_table_reference(self):
    table_ref = bigquery.TableReference()
    table_ref.projectId = 'test_project'
    table_ref.datasetId = 'test_dataset'
    table_ref.tableId = 'test_table'
    parsed_ref = parse_table_reference(table_ref)
    self.assertEqual(table_ref, parsed_ref)
    self.assertIsNot(table_ref, parsed_ref)

  def test_calling_with_callable(self):
    callable_ref = lambda: 'foo'
    parsed_ref = parse_table_reference(callable_ref)
    self.assertIs(callable_ref, parsed_ref)

  def test_calling_with_value_provider(self):
    value_provider_ref = StaticValueProvider(str, 'test_dataset.test_table')
    parsed_ref = parse_table_reference(value_provider_ref)
    self.assertIs(value_provider_ref, parsed_ref)

  @parameterized.expand([
      ('project:dataset.test_table', 'project', 'dataset', 'test_table'),
      ('project:dataset.test-table', 'project', 'dataset', 'test-table'),
      ('project:dataset.test- table', 'project', 'dataset', 'test- table'),
      ('project.dataset. test_table', 'project', 'dataset', ' test_table'),
      ('project.dataset.test$table', 'project', 'dataset', 'test$table'),
  ])
  def test_calling_with_fully_qualified_table_ref(
      self,
      fully_qualified_table: str,
      project_id: str,
      dataset_id: str,
      table_id: str,
  ):
    parsed_ref = parse_table_reference(fully_qualified_table)
    self.assertIsInstance(parsed_ref, bigquery.TableReference)
    self.assertEqual(parsed_ref.projectId, project_id)
    self.assertEqual(parsed_ref.datasetId, dataset_id)
    self.assertEqual(parsed_ref.tableId, table_id)

  def test_calling_with_partially_qualified_table_ref(self):
    datasetId = 'test_dataset'
    tableId = 'test_table'
    partially_qualified_table = '{}.{}'.format(datasetId, tableId)
    parsed_ref = parse_table_reference(partially_qualified_table)
    self.assertIsInstance(parsed_ref, bigquery.TableReference)
    self.assertEqual(parsed_ref.datasetId, datasetId)
    self.assertEqual(parsed_ref.tableId, tableId)

  def test_calling_with_insufficient_table_ref(self):
    table = 'test_table'
    self.assertRaises(ValueError, parse_table_reference, table)

  def test_calling_with_all_arguments(self):
    projectId = 'test_project'
    datasetId = 'test_dataset'
    tableId = 'test_table'
    parsed_ref = parse_table_reference(
        tableId, dataset=datasetId, project=projectId)
    self.assertIsInstance(parsed_ref, bigquery.TableReference)
    self.assertEqual(parsed_ref.projectId, projectId)
    self.assertEqual(parsed_ref.datasetId, datasetId)
    self.assertEqual(parsed_ref.tableId, tableId)


@unittest.skipIf(HttpError is None, 'GCP dependencies are not installed')
class TestBigQueryWrapper(unittest.TestCase):
  def test_delete_non_existing_dataset(self):
    client = mock.Mock()
    client.datasets.Delete.side_effect = HttpError(
        response={'status': '404'}, url='', content='')
    wrapper = beam.io.gcp.bigquery_tools.BigQueryWrapper(client)
    wrapper._delete_dataset('', '')
    self.assertTrue(client.datasets.Delete.called)

  @mock.patch('time.sleep', return_value=None)
  def test_delete_dataset_retries_fail(self, patched_time_sleep):
    client = mock.Mock()
    client.datasets.Delete.side_effect = ValueError("Cannot delete")
    wrapper = beam.io.gcp.bigquery_tools.BigQueryWrapper(client)
    with self.assertRaises(ValueError):
      wrapper._delete_dataset('', '')
    self.assertEqual(
        beam.io.gcp.bigquery_tools.MAX_RETRIES + 1,
        client.datasets.Delete.call_count)
    self.assertTrue(client.datasets.Delete.called)

  def test_delete_non_existing_table(self):
    client = mock.Mock()
    client.tables.Delete.side_effect = HttpError(
        response={'status': '404'}, url='', content='')
    wrapper = beam.io.gcp.bigquery_tools.BigQueryWrapper(client)
    wrapper._delete_table('', '', '')
    self.assertTrue(client.tables.Delete.called)

  @mock.patch('time.sleep', return_value=None)
  def test_delete_table_retries_fail(self, patched_time_sleep):
    client = mock.Mock()
    client.tables.Delete.side_effect = ValueError("Cannot delete")
    wrapper = beam.io.gcp.bigquery_tools.BigQueryWrapper(client)
    with self.assertRaises(ValueError):
      wrapper._delete_table('', '', '')
    self.assertTrue(client.tables.Delete.called)

  @mock.patch('time.sleep', return_value=None)
  def test_delete_dataset_retries_for_timeouts(self, patched_time_sleep):
    client = mock.Mock()
    client.datasets.Delete.side_effect = [
        HttpError(response={'status': '408'}, url='', content=''),
        bigquery.BigqueryDatasetsDeleteResponse()
    ]
    wrapper = beam.io.gcp.bigquery_tools.BigQueryWrapper(client)
    wrapper._delete_dataset('', '')
    self.assertTrue(client.datasets.Delete.called)

  # the function _insert_all_rows() in the wrapper calls google.cloud.bigquery,
  # so we have to skip that when this library is not accessible
  @unittest.skipIf(
      beam.io.gcp.bigquery_tools.gcp_bigquery is None,
      "bigquery library not available in this env")
  @mock.patch('time.sleep', return_value=None)
  @mock.patch(
      'apitools.base.py.base_api._SkipGetCredentials', return_value=True)
  @mock.patch('google.cloud._http.JSONConnection.http')
  def test_user_agent_insert_all(
      self, http_mock, patched_skip_get_credentials, patched_sleep):
    wrapper = beam.io.gcp.bigquery_tools.BigQueryWrapper()
    try:
      wrapper._insert_all_rows('p', 'd', 't', [{'name': 'any'}], None)
    except:  # pylint: disable=bare-except
      # Ignore errors. The errors come from the fact that we did not mock
      # the response from the API, so the overall insert_all_rows call fails
      # soon after the BQ API is called.
      pass
    call = http_mock.request.mock_calls[-2]
    self.assertIn('apache-beam-', call[2]['headers']['User-Agent'])

  # the function create_temporary_dataset() in the wrapper does not call
  # google.cloud.bigquery, so it is fine to just mock it
  @mock.patch(
      'apache_beam.io.gcp.bigquery_tools.gcp_bigquery',
      return_value=mock.Mock())
  @mock.patch(
      'apitools.base.py.base_api._SkipGetCredentials', return_value=True)
  @mock.patch('time.sleep', return_value=None)
  def test_user_agent_create_temporary_dataset(
      self, sleep_mock, skip_get_credentials_mock, gcp_bigquery_mock):
    wrapper = beam.io.gcp.bigquery_tools.BigQueryWrapper()
    request_mock = mock.Mock()
    wrapper.client._http.request = request_mock
    try:
      wrapper.create_temporary_dataset('project-id', 'location')
    except:  # pylint: disable=bare-except
      # Ignore errors. The errors come from the fact that we did not mock
      # the response from the API, so the overall create_dataset call fails
      # soon after the BQ API is called.
      pass
    call = request_mock.mock_calls[-1]
    self.assertIn('apache-beam-', call[2]['headers']['user-agent'])

  @mock.patch('time.sleep', return_value=None)
  def test_delete_table_retries_for_timeouts(self, patched_time_sleep):
    client = mock.Mock()
    client.tables.Delete.side_effect = [
        HttpError(response={'status': '408'}, url='', content=''),
        bigquery.BigqueryTablesDeleteResponse()
    ]
    wrapper = beam.io.gcp.bigquery_tools.BigQueryWrapper(client)
    wrapper._delete_table('', '', '')
    self.assertTrue(client.tables.Delete.called)

  @mock.patch('time.sleep', return_value=None)
  def test_temporary_dataset_is_unique(self, patched_time_sleep):
    client = mock.Mock()
    client.datasets.Get.return_value = bigquery.Dataset(
        datasetReference=bigquery.DatasetReference(
            projectId='project-id', datasetId='dataset_id'))
    wrapper = beam.io.gcp.bigquery_tools.BigQueryWrapper(client)
    with self.assertRaises(RuntimeError):
      wrapper.create_temporary_dataset('project-id', 'location')
    self.assertTrue(client.datasets.Get.called)

  def test_get_or_create_dataset_created(self):
    client = mock.Mock()
    client.datasets.Get.side_effect = HttpError(
        response={'status': '404'}, url='', content='')
    client.datasets.Insert.return_value = bigquery.Dataset(
        datasetReference=bigquery.DatasetReference(
            projectId='project-id', datasetId='dataset_id'))
    wrapper = beam.io.gcp.bigquery_tools.BigQueryWrapper(client)
    new_dataset = wrapper.get_or_create_dataset('project-id', 'dataset_id')
    self.assertEqual(new_dataset.datasetReference.datasetId, 'dataset_id')

  def test_create_temporary_dataset_with_kms_key(self):
    kms_key = (
        'projects/my-project/locations/global/keyRings/my-kr/'
        'cryptoKeys/my-key')
    client = mock.Mock()
    client.datasets.Get.side_effect = HttpError(
        response={'status': '404'}, url='', content='')

    client.datasets.Insert.return_value = bigquery.Dataset(
        datasetReference=bigquery.DatasetReference(
            projectId='project-id', datasetId='temp_dataset'))
    wrapper = beam.io.gcp.bigquery_tools.BigQueryWrapper(client)

    try:
      wrapper.create_temporary_dataset(
          'project-id', 'location', kms_key=kms_key)
    except Exception:
      pass

    args, _ = client.datasets.Insert.call_args
    insert_request = args[0]  # BigqueryDatasetsInsertRequest
    inserted_dataset = insert_request.dataset  # Actual Dataset object

    # Assertions
    self.assertIsNotNone(inserted_dataset.defaultEncryptionConfiguration)
    self.assertEqual(
        inserted_dataset.defaultEncryptionConfiguration.kmsKeyName, kms_key)

  def test_get_or_create_dataset_fetched(self):
    client = mock.Mock()
    client.datasets.Get.return_value = bigquery.Dataset(
        datasetReference=bigquery.DatasetReference(
            projectId='project-id', datasetId='dataset_id'))
    wrapper = beam.io.gcp.bigquery_tools.BigQueryWrapper(client)
    new_dataset = wrapper.get_or_create_dataset('project-id', 'dataset_id')
    self.assertEqual(new_dataset.datasetReference.datasetId, 'dataset_id')

  def test_get_or_create_table(self):
    client = mock.Mock()
    client.tables.Insert.return_value = 'table_id'
    client.tables.Get.side_effect = [None, 'table_id']
    wrapper = beam.io.gcp.bigquery_tools.BigQueryWrapper(client)
    new_table = wrapper.get_or_create_table(
        'project-id',
        'dataset_id',
        'table_id',
        bigquery.TableSchema(
            fields=[
                bigquery.TableFieldSchema(
                    name='b', type='BOOLEAN', mode='REQUIRED')
            ]),
        False,
        False)
    self.assertEqual(new_table, 'table_id')

  def test_get_or_create_table_race_condition(self):
    client = mock.Mock()
    client.tables.Insert.side_effect = HttpError(
        response={'status': '409'}, url='', content='')
    client.tables.Get.side_effect = [None, 'table_id']
    wrapper = beam.io.gcp.bigquery_tools.BigQueryWrapper(client)
    new_table = wrapper.get_or_create_table(
        'project-id',
        'dataset_id',
        'table_id',
        bigquery.TableSchema(
            fields=[
                bigquery.TableFieldSchema(
                    name='b', type='BOOLEAN', mode='REQUIRED')
            ]),
        False,
        False)
    self.assertEqual(new_table, 'table_id')

  def test_get_or_create_table_intermittent_exception(self):
    client = mock.Mock()
    client.tables.Insert.side_effect = [
        HttpError(response={'status': '408'}, url='', content=''), 'table_id'
    ]
    client.tables.Get.side_effect = [None, 'table_id']
    wrapper = beam.io.gcp.bigquery_tools.BigQueryWrapper(client)
    new_table = wrapper.get_or_create_table(
        'project-id',
        'dataset_id',
        'table_id',
        bigquery.TableSchema(
            fields=[
                bigquery.TableFieldSchema(
                    name='b', type='BOOLEAN', mode='REQUIRED')
            ]),
        False,
        False)
    self.assertEqual(new_table, 'table_id')

  @parameterized.expand(['', 'a' * 1025])
  def test_get_or_create_table_invalid_tablename(self, table_id):
    client = mock.Mock()
    client.tables.Get.side_effect = [None]
    wrapper = beam.io.gcp.bigquery_tools.BigQueryWrapper(client)

    self.assertRaises(
        ValueError,
        wrapper.get_or_create_table,
        'project-id',
        'dataset_id',
        table_id,
        bigquery.TableSchema(
            fields=[
                bigquery.TableFieldSchema(
                    name='b', type='BOOLEAN', mode='REQUIRED')
            ]),
        False,
        False)

  def test_wait_for_job_returns_true_when_job_is_done(self):
    def make_response(state):
      m = mock.Mock()
      m.status.errorResult = None
      m.status.state = state
      return m

    client, job_ref = mock.Mock(), mock.Mock()
    wrapper = beam.io.gcp.bigquery_tools.BigQueryWrapper(client)
    # Return 'DONE' the second time get_job is called.
    wrapper.get_job = mock.Mock(
        side_effect=[make_response('RUNNING'), make_response('DONE')])

    result = wrapper.wait_for_bq_job(
        job_ref, sleep_duration_sec=0, max_retries=5)
    self.assertTrue(result)

  def test_wait_for_job_retries_fail(self):
    client, response, job_ref = mock.Mock(), mock.Mock(), mock.Mock()
    response.status.state = 'RUNNING'
    wrapper = beam.io.gcp.bigquery_tools.BigQueryWrapper(client)
    # Return 'RUNNING' response forever.
    wrapper.get_job = lambda *args: response

    with self.assertRaises(RuntimeError) as context:
      wrapper.wait_for_bq_job(job_ref, sleep_duration_sec=0, max_retries=5)
    self.assertEqual(
        'The maximum number of retries has been reached',
        str(context.exception))

  def test_get_query_location(self):
    client = mock.Mock()
    query = """
        SELECT
            av.column1, table.column1
        FROM `dataset.authorized_view` as av
        JOIN `dataset.table` as table ON av.column2 = table.column2
    """
    job = mock.MagicMock(spec=bigquery.Job)
    job.statistics.query.referencedTables = [
        bigquery.TableReference(
            projectId="first_project_id",
            datasetId="first_dataset",
            tableId="table_used_by_authorized_view"),
        bigquery.TableReference(
            projectId="second_project_id",
            datasetId="second_dataset",
            tableId="table"),
    ]
    client.jobs.Insert.return_value = job

    wrapper = beam.io.gcp.bigquery_tools.BigQueryWrapper(client)
    wrapper.get_table_location = mock.Mock(
        side_effect=[
            HttpForbiddenError(response={'status': '404'}, url='', content=''),
            "US"
        ])
    location = wrapper.get_query_location(
        project_id="second_project_id", query=query, use_legacy_sql=False)
    self.assertEqual("US", location)

  def test_perform_load_job_source_mutual_exclusivity(self):
    client = mock.Mock()
    wrapper = beam.io.gcp.bigquery_tools.BigQueryWrapper(client)

    # Both source_uri and source_stream specified.
    with self.assertRaises(ValueError):
      wrapper.perform_load_job(
          destination=parse_table_reference('project:dataset.table'),
          job_id='job_id',
          source_uris=['gs://example.com/*'],
          source_stream=io.BytesIO())

    # Neither source_uri nor source_stream specified.
    wrapper.perform_load_job(
        destination=parse_table_reference('project:dataset.table'), job_id='J')

  def test_perform_load_job_with_source_stream(self):
    client = mock.Mock()
    wrapper = beam.io.gcp.bigquery_tools.BigQueryWrapper(client)

    wrapper.perform_load_job(
        destination=parse_table_reference('project:dataset.table'),
        job_id='job_id',
        source_stream=io.BytesIO(b'some,data'))

    client.jobs.Insert.assert_called_once()
    upload = client.jobs.Insert.call_args[1]["upload"]
    self.assertEqual(b'some,data', upload.stream.read())

  def test_perform_load_job_with_load_job_id(self):
    client = mock.Mock()
    wrapper = beam.io.gcp.bigquery_tools.BigQueryWrapper(client)

    wrapper.perform_load_job(
        destination=parse_table_reference('project:dataset.table'),
        job_id='job_id',
        source_uris=['gs://example.com/*'],
        load_job_project_id='loadId')
    call_args = client.jobs.Insert.call_args
    self.assertEqual('loadId', call_args[0][0].projectId)

  def verify_write_call_metric(
      self, project_id, dataset_id, table_id, status, count):
    """Check if an metric was recorded for the BQ IO write API call."""
    process_wide_monitoring_infos = list(
        MetricsEnvironment.process_wide_container().
        to_runner_api_monitoring_infos(None).values())
    resource = resource_identifiers.BigQueryTable(
        project_id, dataset_id, table_id)
    labels = {
        # TODO(ajamato): Add Ptransform label.
        monitoring_infos.SERVICE_LABEL: 'BigQuery',
        # Refer to any method which writes elements to BigQuery in batches
        # as "BigQueryBatchWrite". I.e. storage API's insertAll, or future
        # APIs introduced.
        monitoring_infos.METHOD_LABEL: 'BigQueryBatchWrite',
        monitoring_infos.RESOURCE_LABEL: resource,
        monitoring_infos.BIGQUERY_PROJECT_ID_LABEL: project_id,
        monitoring_infos.BIGQUERY_DATASET_LABEL: dataset_id,
        monitoring_infos.BIGQUERY_TABLE_LABEL: table_id,
        monitoring_infos.STATUS_LABEL: status,
    }
    expected_mi = monitoring_infos.int64_counter(
        monitoring_infos.API_REQUEST_COUNT_URN, count, labels=labels)
    expected_mi.ClearField("start_time")

    found = False
    for actual_mi in process_wide_monitoring_infos:
      actual_mi.ClearField("start_time")
      if expected_mi == actual_mi:
        found = True
        break
    self.assertTrue(
        found, "Did not find write call metric with status: %s" % status)

  @unittest.skipIf(ClientError is None, 'GCP dependencies are not installed')
  def test_insert_rows_sets_metric_on_failure(self):
    MetricsEnvironment.process_wide_container().reset()
    client = mock.Mock()
    client.insert_rows_json = mock.Mock(
        # Fail a few times, then succeed.
        side_effect=[
            DeadlineExceeded("Deadline Exceeded"),
            InternalServerError("Internal Error"),
            [],
        ])
    wrapper = beam.io.gcp.bigquery_tools.BigQueryWrapper(client)
    wrapper.insert_rows("my_project", "my_dataset", "my_table", [])

    # Expect two failing calls, then a success (i.e. two retries).
    self.verify_write_call_metric(
        "my_project", "my_dataset", "my_table", "deadline_exceeded", 1)
    self.verify_write_call_metric(
        "my_project", "my_dataset", "my_table", "internal", 1)
    self.verify_write_call_metric(
        "my_project", "my_dataset", "my_table", "ok", 1)

  @unittest.skipIf(ClientError is None, 'GCP dependencies are not installed')
  def test_start_query_job_priority_configuration(self):
    client = mock.Mock()
    wrapper = beam.io.gcp.bigquery_tools.BigQueryWrapper(client)

    query_result = mock.Mock()
    query_result.pageToken = None
    wrapper._get_query_results = mock.Mock(return_value=query_result)

    wrapper._start_query_job(
        "my_project",
        "my_query",
        use_legacy_sql=False,
        flatten_results=False,
        job_id="my_job_id",
        priority=beam.io.BigQueryQueryPriority.BATCH)

    self.assertEqual(
        client.jobs.Insert.call_args[0][0].job.configuration.query.priority,
        'BATCH')

    wrapper._start_query_job(
        "my_project",
        "my_query",
        use_legacy_sql=False,
        flatten_results=False,
        job_id="my_job_id",
        priority=beam.io.BigQueryQueryPriority.INTERACTIVE)

    self.assertEqual(
        client.jobs.Insert.call_args[0][0].job.configuration.query.priority,
        'INTERACTIVE')

  def test_get_temp_table_project_with_temp_table_ref(self):
    """Test _get_temp_table_project returns project from temp_table_ref."""
    client = mock.Mock()
    temp_table_ref = bigquery.TableReference(
        projectId='temp-project',
        datasetId='temp_dataset',
        tableId='temp_table')
    wrapper = beam.io.gcp.bigquery_tools.BigQueryWrapper(
        client, temp_table_ref=temp_table_ref)

    result = wrapper._get_temp_table_project('fallback-project')
    self.assertEqual(result, 'temp-project')

  def test_get_temp_table_project_without_temp_table_ref(self):
    """Test _get_temp_table_project returns fallback when no temp_table_ref."""
    client = mock.Mock()
    wrapper = beam.io.gcp.bigquery_tools.BigQueryWrapper(client)

    result = wrapper._get_temp_table_project('fallback-project')
    self.assertEqual(result, 'fallback-project')


@unittest.skipIf(HttpError is None, 'GCP dependencies are not installed')
class TestRowAsDictJsonCoder(unittest.TestCase):
  def test_row_as_dict(self):
    coder = RowAsDictJsonCoder()
    test_value = {'s': 'abc', 'i': 123, 'f': 123.456, 'b': True}
    self.assertEqual(test_value, coder.decode(coder.encode(test_value)))

  def test_decimal_in_row_as_dict(self):
    decimal_value = decimal.Decimal('123456789.987654321')
    coder = RowAsDictJsonCoder()
    # Bigquery IO uses decimals to represent NUMERIC types.
    # To export to BQ, it's necessary to convert to strings, due to the
    # lower precision of JSON numbers. This means that we can't recognize
    # a NUMERIC when we decode from JSON, thus we match the string here.
    test_value = {'f': 123.456, 'b': True, 'numerico': decimal_value}
    output_value = {'f': 123.456, 'b': True, 'numerico': str(decimal_value)}
    self.assertEqual(output_value, coder.decode(coder.encode(test_value)))

  def json_compliance_exception(self, value):
    with self.assertRaisesRegex(ValueError, re.escape(JSON_COMPLIANCE_ERROR)):
      coder = RowAsDictJsonCoder()
      test_value = {'s': value}
      coder.decode(coder.encode(test_value))

  def test_invalid_json_nan(self):
    self.json_compliance_exception(float('nan'))

  def test_invalid_json_inf(self):
    self.json_compliance_exception(float('inf'))

  def test_invalid_json_neg_inf(self):
    self.json_compliance_exception(float('-inf'))

  def test_ensure_ascii(self):
    coder = RowAsDictJsonCoder()
    test_value = {'s': '🎉'}
    output_value = b'{"s": "\xf0\x9f\x8e\x89"}'

    self.assertEqual(output_value, coder.encode(test_value))


@unittest.skipIf(HttpError is None, 'GCP dependencies are not installed')
class TestJsonRowWriter(unittest.TestCase):
  def test_write_row(self):
    rows = [
        {
            'name': 'beam', 'game': 'dream'
        },
        {
            'name': 'team', 'game': 'cream'
        },
    ]

    with io.BytesIO() as buf:
      # Mock close() so we can access the buffer contents
      # after JsonRowWriter is closed.
      with mock.patch.object(buf, 'close') as mock_close:
        writer = JsonRowWriter(buf)
        for row in rows:
          writer.write(row)
        writer.close()

        mock_close.assert_called_once()

      buf.seek(0)
      read_rows = [
          json.loads(row)
          for row in buf.getvalue().strip().decode('utf-8').split('\n')
      ]

    self.assertEqual(read_rows, rows)


@unittest.skipIf(HttpError is None, 'GCP dependencies are not installed')
class TestAvroRowWriter(unittest.TestCase):
  def test_write_row(self):
    schema = bigquery.TableSchema(
        fields=[
            bigquery.TableFieldSchema(name='stamp', type='TIMESTAMP'),
            bigquery.TableFieldSchema(
                name='number', type='FLOAT', mode='REQUIRED'),
        ])
    stamp = datetime.datetime(2020, 2, 25, 12, 0, 0, tzinfo=pytz.utc)

    with io.BytesIO() as buf:
      # Mock close() so we can access the buffer contents
      # after AvroRowWriter is closed.
      with mock.patch.object(buf, 'close') as mock_close:
        writer = AvroRowWriter(buf, schema)
        writer.write({'stamp': stamp, 'number': float('NaN')})
        writer.close()

        mock_close.assert_called_once()

      buf.seek(0)
      records = [r for r in fastavro.reader(buf)]

    self.assertEqual(len(records), 1)
    self.assertTrue(math.isnan(records[0]['number']))
    self.assertEqual(records[0]['stamp'], stamp)


class TestBQJobNames(unittest.TestCase):
  def test_simple_names(self):
    self.assertEqual(
        "beam_bq_job_EXPORT_beamappjobtest_abcd",
        generate_bq_job_name(
            "beamapp-job-test", "abcd", BigQueryJobTypes.EXPORT))

    self.assertEqual(
        "beam_bq_job_LOAD_beamappjobtest_abcd",
        generate_bq_job_name("beamapp-job-test", "abcd", BigQueryJobTypes.LOAD))

    self.assertEqual(
        "beam_bq_job_QUERY_beamappjobtest_abcd",
        generate_bq_job_name(
            "beamapp-job-test", "abcd", BigQueryJobTypes.QUERY))

    self.assertEqual(
        "beam_bq_job_COPY_beamappjobtest_abcd",
        generate_bq_job_name("beamapp-job-test", "abcd", BigQueryJobTypes.COPY))

  def test_random_in_name(self):
    self.assertEqual(
        "beam_bq_job_COPY_beamappjobtest_abcd_randome",
        generate_bq_job_name(
            "beamapp-job-test", "abcd", BigQueryJobTypes.COPY, "randome"))

  def test_matches_template(self):
    base_pattern = "beam_bq_job_[A-Z]+_[a-z0-9-]+_[a-z0-9-]+(_[a-z0-9-]+)?"
    job_name = generate_bq_job_name(
        "beamapp-job-test", "abcd", BigQueryJobTypes.COPY, "randome")
    self.assertRegex(job_name, base_pattern)

    job_name = generate_bq_job_name(
        "beamapp-job-test", "abcd", BigQueryJobTypes.COPY)
    self.assertRegex(job_name, base_pattern)


@unittest.skipIf(HttpError is None, 'GCP dependencies are not installed')
class TestCheckSchemaEqual(unittest.TestCase):
  def test_simple_schemas(self):
    schema1 = bigquery.TableSchema(fields=[])
    self.assertTrue(check_schema_equal(schema1, schema1))

    schema2 = bigquery.TableSchema(
        fields=[
            bigquery.TableFieldSchema(name="a", mode="NULLABLE", type="INT64")
        ])
    self.assertTrue(check_schema_equal(schema2, schema2))
    self.assertFalse(check_schema_equal(schema1, schema2))

    schema3 = bigquery.TableSchema(
        fields=[
            bigquery.TableFieldSchema(
                name="b",
                mode="REPEATED",
                type="RECORD",
                fields=[
                    bigquery.TableFieldSchema(
                        name="c", mode="REQUIRED", type="BOOL")
                ])
        ])
    self.assertTrue(check_schema_equal(schema3, schema3))
    self.assertFalse(check_schema_equal(schema2, schema3))

  def test_field_order(self):
    """Test that field order is ignored when ignore_field_order=True."""
    schema1 = bigquery.TableSchema(
        fields=[
            bigquery.TableFieldSchema(
                name="a", mode="REQUIRED", type="FLOAT64"),
            bigquery.TableFieldSchema(name="b", mode="REQUIRED", type="INT64"),
        ])

    schema2 = bigquery.TableSchema(fields=list(reversed(schema1.fields)))

    self.assertFalse(check_schema_equal(schema1, schema2))
    self.assertTrue(
        check_schema_equal(schema1, schema2, ignore_field_order=True))

  def test_descriptions(self):
    """
        Test that differences in description are ignored
        when ignore_descriptions=True.
        """
    schema1 = bigquery.TableSchema(
        fields=[
            bigquery.TableFieldSchema(
                name="a",
                mode="REQUIRED",
                type="FLOAT64",
                description="Field A",
            ),
            bigquery.TableFieldSchema(
                name="b",
                mode="REQUIRED",
                type="INT64",
            ),
        ])

    schema2 = bigquery.TableSchema(
        fields=[
            bigquery.TableFieldSchema(
                name="a",
                mode="REQUIRED",
                type="FLOAT64",
                description="Field A is for Apple"),
            bigquery.TableFieldSchema(
                name="b",
                mode="REQUIRED",
                type="INT64",
                description="Field B",
            ),
        ])

    self.assertFalse(check_schema_equal(schema1, schema2))
    self.assertTrue(
        check_schema_equal(schema1, schema2, ignore_descriptions=True))


@unittest.skipIf(HttpError is None, 'GCP dependencies are not installed')
class TestBeamRowFromDict(unittest.TestCase):
  DICT_ROW = {
      "str": "a",
      "bool": True,
      "bytes": b'a',
      "int": 1,
      "float": 0.1,
      "numeric": decimal.Decimal("1.11"),
      "timestamp": Timestamp(1000, 100)
  }

  def get_schema_fields_with_mode(self, mode):
    return [{
        "name": "str", "type": "STRING", "mode": mode
    }, {
        "name": "bool", "type": "boolean", "mode": mode
    }, {
        "name": "bytes", "type": "BYTES", "mode": mode
    }, {
        "name": "int", "type": "INTEGER", "mode": mode
    }, {
        "name": "float", "type": "Float", "mode": mode
    }, {
        "name": "numeric", "type": "NUMERIC", "mode": mode
    }, {
        "name": "timestamp", "type": "TIMESTAMP", "mode": mode
    }]

  def test_dict_to_beam_row_all_types_required(self):
    schema = {"fields": self.get_schema_fields_with_mode("REQUIRED")}
    expected_beam_row = beam.Row(
        str="a",
        bool=True,
        bytes=b'a',
        int=1,
        float=0.1,
        numeric=decimal.Decimal("1.11"),
        timestamp=Timestamp(1000, 100))

    self.assertEqual(
        expected_beam_row, beam_row_from_dict(self.DICT_ROW, schema))

  def test_dict_to_beam_row_all_types_repeated(self):
    schema = {"fields": self.get_schema_fields_with_mode("REPEATED")}
    dict_row = {
        "str": ["a", "b"],
        "bool": [True, False],
        "bytes": [b'a', b'b'],
        "int": [1, 2],
        "float": [0.1, 0.2],
        "numeric": [decimal.Decimal("1.11"), decimal.Decimal("2.22")],
        "timestamp": [Timestamp(1000, 100), Timestamp(2000, 200)]
    }

    expected_beam_row = beam.Row(
        str=["a", "b"],
        bool=[True, False],
        bytes=[b'a', b'b'],
        int=[1, 2],
        float=[0.1, 0.2],
        numeric=[decimal.Decimal("1.11"), decimal.Decimal("2.22")],
        timestamp=[Timestamp(1000, 100), Timestamp(2000, 200)])

    self.assertEqual(expected_beam_row, beam_row_from_dict(dict_row, schema))

  def test_dict_to_beam_row_all_types_nullable(self):
    schema_fields_with_nested = [{
        "name": "nested_record",
        "type": "record",
        "mode": "repeated",
        "fields": self.get_schema_fields_with_mode("nullable")
    }]
    schema_fields_with_nested.extend(
        self.get_schema_fields_with_mode("nullable"))
    schema = {"fields": schema_fields_with_nested}
    dict_row = {k: None for k in self.DICT_ROW}

    # input dict row with missing nullable fields should still yield a full
    # Beam Row
    del dict_row['str']
    del dict_row['bool']

    expected_beam_row = beam.Row(
        nested_record=None,
        str=None,
        bool=None,
        bytes=None,
        int=None,
        float=None,
        numeric=None,
        timestamp=None)

    self.assertEqual(expected_beam_row, beam_row_from_dict(dict_row, schema))

  def test_dict_to_beam_row_nested_record(self):
    schema_fields_with_nested = [{
        "name": "nested_record",
        "type": "record",
        "fields": self.get_schema_fields_with_mode("required")
    }]
    schema_fields_with_nested.extend(
        self.get_schema_fields_with_mode("required"))
    schema = {"fields": schema_fields_with_nested}

    dict_row = {
        "nested_record": self.DICT_ROW,
        "str": "a",
        "bool": True,
        "bytes": b'a',
        "int": 1,
        "float": 0.1,
        "numeric": decimal.Decimal("1.11"),
        "timestamp": Timestamp(1000, 100)
    }
    expected_beam_row = beam.Row(
        nested_record=beam.Row(
            str="a",
            bool=True,
            bytes=b'a',
            int=1,
            float=0.1,
            numeric=decimal.Decimal("1.11"),
            timestamp=Timestamp(1000, 100)),
        str="a",
        bool=True,
        bytes=b'a',
        int=1,
        float=0.1,
        numeric=decimal.Decimal("1.11"),
        timestamp=Timestamp(1000, 100))

    self.assertEqual(expected_beam_row, beam_row_from_dict(dict_row, schema))

  def test_dict_to_beam_row_repeated_nested_record(self):
    schema_fields_with_repeated_nested_record = [{
        "name": "nested_repeated_record",
        "type": "record",
        "mode": "repeated",
        "fields": self.get_schema_fields_with_mode("required")
    }]
    schema = {"fields": schema_fields_with_repeated_nested_record}

    dict_row = {
        "nested_repeated_record": [self.DICT_ROW, self.DICT_ROW, self.DICT_ROW],
    }

    beam_row = beam.Row(
        str="a",
        bool=True,
        bytes=b'a',
        int=1,
        float=0.1,
        numeric=decimal.Decimal("1.11"),
        timestamp=Timestamp(1000, 100))
    expected_beam_row = beam.Row(
        nested_repeated_record=[beam_row, beam_row, beam_row])

    self.assertEqual(expected_beam_row, beam_row_from_dict(dict_row, schema))


@unittest.skipIf(HttpError is None, 'GCP dependencies are not installed')
class TestBeamTypehintFromSchema(unittest.TestCase):
  EXPECTED_TYPEHINTS = [("str", str), ("bool", bool), ("bytes", bytes),
                        ("int", np.int64), ("float", np.float64),
                        ("numeric", decimal.Decimal), ("timestamp", Timestamp)]

  def get_schema_fields_with_mode(self, mode):
    return [{
        "name": "str", "type": "STRING", "mode": mode
    }, {
        "name": "bool", "type": "boolean", "mode": mode
    }, {
        "name": "bytes", "type": "BYTES", "mode": mode
    }, {
        "name": "int", "type": "INTEGER", "mode": mode
    }, {
        "name": "float", "type": "Float", "mode": mode
    }, {
        "name": "numeric", "type": "NUMERIC", "mode": mode
    }, {
        "name": "timestamp", "type": "TIMESTAMP", "mode": mode
    }]

  def test_typehints_from_required_schema(self):
    schema = {"fields": self.get_schema_fields_with_mode("required")}
    typehints = get_beam_typehints_from_tableschema(schema)

    self.assertEqual(typehints, self.EXPECTED_TYPEHINTS)

  def test_typehints_from_repeated_schema(self):
    schema = {"fields": self.get_schema_fields_with_mode("repeated")}
    typehints = get_beam_typehints_from_tableschema(schema)

    expected_repeated_typehints = [(name, Sequence[type])
                                   for name, type in self.EXPECTED_TYPEHINTS]

    self.assertEqual(typehints, expected_repeated_typehints)

  def test_typehints_from_nullable_schema(self):
    schema = {"fields": self.get_schema_fields_with_mode("nullable")}
    typehints = get_beam_typehints_from_tableschema(schema)

    expected_nullable_typehints = [(name, Optional[type])
                                   for name, type in self.EXPECTED_TYPEHINTS]

    self.assertEqual(typehints, expected_nullable_typehints)

  def test_typehints_from_schema_with_struct(self):
    schema = {
        "fields": [{
            "name": "record",
            "type": "record",
            "mode": "required",
            "fields": self.get_schema_fields_with_mode("required")
        }]
    }
    typehints = get_beam_typehints_from_tableschema(schema)

    expected_typehints = [
        ("record", RowTypeConstraint.from_fields(self.EXPECTED_TYPEHINTS))
    ]

    self.assertEqual(typehints, expected_typehints)

  def test_typehints_from_schema_with_repeated_struct(self):
    schema = {
        "fields": [{
            "name": "record",
            "type": "record",
            "mode": "repeated",
            "fields": self.get_schema_fields_with_mode("required")
        }]
    }
    typehints = get_beam_typehints_from_tableschema(schema)

    expected_typehints = [(
        "record",
        Sequence[RowTypeConstraint.from_fields(self.EXPECTED_TYPEHINTS)])]

    self.assertEqual(typehints, expected_typehints)


@unittest.skipIf(HttpError is None, 'GCP dependencies are not installed')
class TestGeographyTypeSupport(unittest.TestCase):
  """Tests for GEOGRAPHY data type support in BigQuery."""
  def test_geography_in_bigquery_type_mapping(self):
    """Test that GEOGRAPHY is properly mapped in type mapping."""
    from apache_beam.io.gcp.bigquery_tools import BIGQUERY_TYPE_TO_PYTHON_TYPE

    self.assertIn("GEOGRAPHY", BIGQUERY_TYPE_TO_PYTHON_TYPE)
    self.assertEqual(BIGQUERY_TYPE_TO_PYTHON_TYPE["GEOGRAPHY"], str)

  def test_geography_field_conversion(self):
    """Test that GEOGRAPHY fields are converted correctly."""
    from apache_beam.io.gcp.bigquery_tools import BigQueryWrapper

    # Create a mock field with GEOGRAPHY type
    field = bigquery.TableFieldSchema()
    field.type = 'GEOGRAPHY'
    field.name = 'location'
    field.mode = 'NULLABLE'

    wrapper = BigQueryWrapper(client=mock.Mock())

    # Test various WKT formats
    test_cases = [
        "POINT(30 10)",
        "LINESTRING(30 10, 10 30, 40 40)",
        "POLYGON((30 10, 40 40, 20 40, 10 20, 30 10))",
        "MULTIPOINT((10 40), (40 30), (20 20), (30 10))",
        "GEOMETRYCOLLECTION(POINT(4 6),LINESTRING(4 6,7 10))"
    ]

    for wkt_value in test_cases:
      result = wrapper._convert_cell_value_to_dict(wkt_value, field)
      self.assertEqual(result, wkt_value)
      self.assertIsInstance(result, str)

  def test_geography_typehints_from_schema(self):
    """Test that GEOGRAPHY fields generate correct type hints."""
    schema = {
        "fields": [{
            "name": "location", "type": "GEOGRAPHY", "mode": "REQUIRED"
        },
                   {
                       "name": "optional_location",
                       "type": "GEOGRAPHY",
                       "mode": "NULLABLE"
                   }, {
                       "name": "locations",
                       "type": "GEOGRAPHY",
                       "mode": "REPEATED"
                   }]
    }

    typehints = get_beam_typehints_from_tableschema(schema)

    expected_typehints = [("location", str),
                          ("optional_location", Optional[str]),
                          ("locations", Sequence[str])]

    self.assertEqual(typehints, expected_typehints)

  def test_geography_beam_row_conversion(self):
    """Test converting dictionary with GEOGRAPHY to Beam Row."""
    schema = {
        "fields": [{
            "name": "id", "type": "INTEGER", "mode": "REQUIRED"
        }, {
            "name": "location", "type": "GEOGRAPHY", "mode": "NULLABLE"
        }, {
            "name": "name", "type": "STRING", "mode": "REQUIRED"
        }]
    }

    row_dict = {"id": 1, "location": "POINT(30 10)", "name": "Test Location"}

    beam_row = beam_row_from_dict(row_dict, schema)

    self.assertEqual(beam_row.id, 1)
    self.assertEqual(beam_row.location, "POINT(30 10)")
    self.assertEqual(beam_row.name, "Test Location")

  def test_geography_beam_row_conversion_with_null(self):
    """Test converting dictionary with null GEOGRAPHY to Beam Row."""
    schema = {
        "fields": [{
            "name": "id", "type": "INTEGER", "mode": "REQUIRED"
        }, {
            "name": "location", "type": "GEOGRAPHY", "mode": "NULLABLE"
        }]
    }

    row_dict = {"id": 1, "location": None}

    beam_row = beam_row_from_dict(row_dict, schema)

    self.assertEqual(beam_row.id, 1)
    self.assertIsNone(beam_row.location)

  def test_geography_beam_row_conversion_repeated(self):
    """Test converting dictionary with repeated GEOGRAPHY to Beam Row."""
    schema = {
        "fields": [{
            "name": "id", "type": "INTEGER", "mode": "REQUIRED"
        }, {
            "name": "locations", "type": "GEOGRAPHY", "mode": "REPEATED"
        }]
    }

    row_dict = {
        "id": 1,
        "locations": ["POINT(30 10)", "POINT(40 20)", "LINESTRING(0 0, 1 1)"]
    }

    beam_row = beam_row_from_dict(row_dict, schema)

    self.assertEqual(beam_row.id, 1)
    self.assertEqual(len(beam_row.locations), 3)
    self.assertEqual(beam_row.locations[0], "POINT(30 10)")
    self.assertEqual(beam_row.locations[1], "POINT(40 20)")
    self.assertEqual(beam_row.locations[2], "LINESTRING(0 0, 1 1)")

  def test_geography_json_encoding(self):
    """Test that GEOGRAPHY values are properly JSON encoded."""
    coder = RowAsDictJsonCoder()

    row_with_geography = {"id": 1, "location": "POINT(30 10)", "name": "Test"}

    encoded = coder.encode(row_with_geography)
    decoded = coder.decode(encoded)

    self.assertEqual(decoded["location"], "POINT(30 10)")
    self.assertIsInstance(decoded["location"], str)

  def test_geography_with_special_characters(self):
    """Test GEOGRAPHY values with special characters and geometries."""
    from apache_beam.io.gcp.bigquery_tools import BigQueryWrapper

    field = bigquery.TableFieldSchema()
    field.type = 'GEOGRAPHY'
    field.name = 'complex_geo'
    field.mode = 'NULLABLE'

    wrapper = BigQueryWrapper(client=mock.Mock())

    # Test complex WKT with various coordinate systems and precision
    complex_wkt = (
        "POLYGON((-122.4194 37.7749, -122.4094 37.7849, "
        "-122.3994 37.7749, -122.4194 37.7749))")

    result = wrapper._convert_cell_value_to_dict(complex_wkt, field)
    self.assertEqual(result, complex_wkt)
    self.assertIsInstance(result, str)


if __name__ == '__main__':
  logging.getLogger().setLevel(logging.INFO)
  unittest.main()
